In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt


INFO:stancache.seed:Setting seed to 1245502385

The model

This style of modeling is often called the "piecewise exponential model", or PEM. It is the simplest case in which the outcome we model is the hazard of an event occurring in a time period, rather than survival (i.e., time to event).

Recall that, in the context of survival modeling, we have two models:

  1. A model for Survival ($S$), i.e. the probability of surviving past time $t$:

    $$ S(t)=Pr(Y > t) $$

  2. A model for the instantaneous hazard $\lambda$, i.e. the rate at which failure events occur in the interval [$t$, $t+\delta t$] as $\delta t \rightarrow 0$, given survival to time $t$:

    $$ \lambda(t) = \lim_{\delta t \rightarrow 0 } \; \frac{Pr( t \le Y \le t + \delta t | Y > t)}{\delta t} $$

By definition, these two are related to one another by the following equation:

$$ \lambda(t) = \frac{-S'(t)}{S(t)} $$

Since $\lambda(t) = -\frac{d}{dt} \log S(t)$, integrating both sides from $0$ to $t$ (and using $S(0) = 1$) yields:

$$ S(t) = \exp\left( -\int_0^t \lambda(z) dz \right) $$

This model is called the piecewise exponential model because of this relationship between the Survival and hazard functions: when the hazard is constant over an interval, survival decays exponentially over that interval. It's piecewise because we are not estimating the instantaneous hazard as a continuous function of time; instead, we break follow-up time into pieces and estimate a constant hazard for each piece.
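With a hazard that is constant within each piece, the integral in the expression for $S(t)$ reduces to a sum: $\int_0^t \lambda(z)\,dz = \sum_j \lambda_j \Delta_j(t)$, where $\lambda_j$ is the hazard in piece $j$ and $\Delta_j(t)$ is the portion of piece $j$ that lies before $t$. The sketch below computes this with NumPy; `pem_survival` is a hypothetical helper for illustration, not part of survivalstan.

```python
import numpy as np

def pem_survival(t, breaks, hazards):
    """Survival S(t) under a piecewise-constant hazard (illustrative sketch).

    breaks:  increasing end times of each piece; the first piece starts at 0
    hazards: constant hazard rate within each piece (same length as breaks)
    """
    breaks = np.asarray(breaks, dtype=float)
    hazards = np.asarray(hazards, dtype=float)
    starts = np.concatenate([[0.0], breaks[:-1]])
    # time spent in each piece before t (0 for pieces that start after t);
    # times beyond the last break are truncated at that break
    exposure = np.clip(np.minimum(breaks, t) - starts, 0.0, None)
    cumulative_hazard = np.sum(hazards * exposure)
    return np.exp(-cumulative_hazard)
```

For example, `pem_survival(3.0, [1, 2, 4], [0.1, 0.2, 0.3])` accumulates one unit of time in each piece, so it returns $\exp(-0.6) \approx 0.549$. The sketch does not extrapolate past the last break, since a piecewise model says nothing about the hazard beyond the last timepoint.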

There are several variations on the PEM implemented in survivalstan. In this notebook, we explore just one of them.

A note about data formatting

When we model Survival, we typically operate on data in time-to-event form. In this form, we have one record per Subject (i.e., per patient). Each record contains [event_status, time_to_event] as the outcome. This data format is sometimes called per-subject.

When we model the hazard by comparison, we typically operate on data that are transformed to include one record per Subject per time_period. This is called per-timepoint or long form.

All other things being equal, a model for Survival will typically fit more efficiently (faster, with a smaller memory footprint) than a model for the hazard, simply because the per-timepoint form has many more rows than the per-subject form. The benefit of hazard models is increased flexibility in specifying the baseline hazard, time-varying effects, and time-varying covariates.

In this example, we demonstrate the standard PEM survival model, which uses data in long form. The Stan code expects to receive data in this structure.
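To make the two layouts concrete, here is a minimal pandas sketch of the per-subject to per-timepoint transformation. It assumes every distinct observed time in the data becomes a timepoint boundary and that a subject contributes rows only up to its own observed time; the helper `to_long_form` is hypothetical, and survivalstan's own `prep_data_long_surv` (used below) may differ in details.

```python
import pandas as pd

def to_long_form(df, time_col='t', event_col='event'):
    """Hypothetical sketch: expand per-subject rows into per-timepoint rows."""
    # Candidate timepoints: every distinct observed time in the data
    timepoints = sorted(df[time_col].unique())
    rows = []
    for _, rec in df.iterrows():
        for tp in timepoints:
            if tp > rec[time_col]:
                break  # subject is no longer at risk after its own observed time
            row = rec.to_dict()
            row['end_time'] = tp
            # the failure flag is set only at the subject's own (uncensored) time
            row['end_failure'] = bool(rec[event_col] and tp == rec[time_col])
            rows.append(row)
    return pd.DataFrame(rows)
```

A subject with an event contributes one row per timepoint up to its event time, with `end_failure` set only on the last row; a censored subject contributes the same rows with no failure flag.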

Stan code for the model

This model is provided in survivalstan.models.pem_survival_model. Let's take a look at the Stan code.


In [2]:
print(survivalstan.models.pem_survival_model)


/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates
 
 // main data matrix (per observed timepoint*record)
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, M]
 
 // timepoint-specific data (per timepoint, ordered by timepoint id)
 t_obs      = observed time since origin for each timepoint id (end of period)
 t_dur      = duration of each timepoint period (first diff of t_obs)
 
*/
// Jacqueline Buros Novik <jackinovik@gmail.com>

data {
  // dimensions
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;
  
  // data matrix
  int<lower=1, upper=N> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars
  
  // timepoint data
  vector<lower=0>[T] t_obs;
  vector<lower=0>[T] t_dur;
}
transformed data {
  vector[T] log_t_dur;  // log-duration for each timepoint
  int n_trans[S, T];  
  
  log_t_dur = log(t_dur);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }  
}
parameters {
  vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
  vector[M] beta;         // beta for each covariate
  real<lower=0> baseline_sigma;
  real log_baseline_mu;
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;     // unstructured baseline hazard for each timepoint t
  
  log_baseline = log_baseline_mu + log_baseline_raw + log_t_dur;
  
  for (n in 1:N) {
    log_hazard[n] = log_baseline[t[n]] + x[n,]*beta;
  }
}
model {
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);
  log_baseline_raw ~ normal(0, baseline_sigma);
}
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)
  
  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_mu + log_baseline_raw);
  
  // prepare log_lik for loo-psis
  for (n in 1:N) {
      log_lik[n] = poisson_log_log(event[n], log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
        if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_haz;
              
              // determine predicted value of this sample's hazard
              n = n_trans[samp, tp];
              log_haz = log_baseline[tp] + x[n,] * beta;
              
              // now, make posterior prediction of an event at this tp
              if (log_haz < log(pow(2, 30))) 
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9; 
              
              // summarize survival time (observed) for this pt
              if (pred_y >= 1) {
                  // mark this patient as ineligible for future tps
                  // note: deliberately treat 9s as events 
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }
              
          }
      } // end per-timepoint loop
      
      // if patient still alive at max
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop  
}
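The `generated quantities` block simulates a predicted failure time for each sample by walking forward through the timepoints and drawing a Poisson count from each period's hazard; the first period with a nonzero count is recorded as the failure time, and samples that never fail are censored at the last timepoint. The NumPy sketch below mirrors that loop for a single posterior draw. It assumes the per-sample, per-timepoint log hazards have already been assembled into an `(S, T)` array (the role `n_trans` plays in the Stan code); `predict_failure_times` is a hypothetical name.

```python
import numpy as np

def predict_failure_times(log_hazard, t_obs, rng):
    """Sketch of the posterior-prediction loop for one posterior draw.

    log_hazard: array (S, T) of log hazards per sample and timepoint
    t_obs:      array (T,) of timepoint end times
    """
    S, T = log_hazard.shape
    y_hat_time = np.full(S, t_obs[-1], dtype=float)  # default: censored at last timepoint
    y_hat_event = np.zeros(S, dtype=int)              # 0 = censored, 1 = event
    for samp in range(S):
        for tp in range(T):
            # cap the rate to avoid overflow, similar in spirit to the Stan guard
            lam = np.exp(min(log_hazard[samp, tp], np.log(2.0 ** 30)))
            if rng.poisson(lam) >= 1:
                y_hat_time[samp] = t_obs[tp]
                y_hat_event[samp] = 1
                break  # sample is no longer at risk
    return y_hat_time, y_hat_event
```

The Stan code handles very large hazards slightly differently (it records a placeholder count of 9 and treats it as an event); capping the rate gives the same practical result.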

Simulate survival data

In order to demonstrate the use of this model, we will first simulate some survival data using survivalstan.sim.sim_data_exp_correlated. As the name implies, this function simulates data assuming a constant hazard throughout the follow-up time period, which is consistent with the Exponential survival function.

This function includes two simulated covariates by default (age and sex). Here, the simulated hazard depends only on sex (rate_form='1 + sex'); age has no effect on the rate.

We also center the age variable, so that the baseline hazard can be interpreted as the hazard for a subject of average age.
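As a rough sketch of what this kind of simulation involves (this is not survivalstan's implementation, and `sim_exp_survival` is a hypothetical helper that omits age), exponential survival times with a sex effect and administrative censoring at `censor_time` can be generated with NumPy as follows. The rate is $\exp(-3 + 0.5 \cdot \text{male})$, matching `rate_form='1 + sex'` and `rate_coefs=[-3, 0.5]`.

```python
import numpy as np
import pandas as pd

def sim_exp_survival(n=100, censor_time=20.0, intercept=-3.0, beta_male=0.5, seed=0):
    """Hypothetical sketch: exponential survival data with administrative censoring."""
    rng = np.random.default_rng(seed)
    male = rng.integers(0, 2, size=n)              # 1 = male, 0 = female
    rate = np.exp(intercept + beta_male * male)     # constant hazard per subject
    true_t = rng.exponential(scale=1.0 / rate)      # NumPy uses scale = 1 / rate
    observed_t = np.minimum(true_t, censor_time)    # censor anyone still alive at censor_time
    event = true_t <= censor_time
    return pd.DataFrame({
        'sex': np.where(male == 1, 'male', 'female'),
        'rate': rate,
        'true_t': true_t,
        't': observed_t,
        'event': event,
    })
```

The resulting rates, $\exp(-3) \approx 0.049787$ for females and $\exp(-2.5) \approx 0.082085$ for males, match the `rate` column in the simulated data shown below.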


In [3]:
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
d['age_centered'] = d['age'] - d['age'].mean()


INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache

Aside: In order to make this a more reproducible example, this code is using a file-caching function stancache.cached to wrap a function call to survivalstan.sim.sim_data_exp_correlated.

Explore simulated data

Here is what these data look like - this is per-subject or time-to-event form:


In [4]:
d.head()


Out[4]:
   age     sex      rate     true_t          t  event  index  age_centered
0   59    male  0.082085  20.948771  20.000000  False      0          4.18
1   58    male  0.082085  12.827519  12.827519   True      1          3.18
2   61  female  0.049787  27.018886  20.000000  False      2          6.18
3   57  female  0.049787  62.220296  20.000000  False      3          2.18
4   55    male  0.082085  10.462045  10.462045   True      4          0.18

It's not obvious from the field names, but in this example subjects are identified by the index column.

We can plot these data using lifelines, or the rudimentary plotting functions provided by survivalstan.


In [5]:
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t', label='male')
plt.legend()


Out[5]:
<matplotlib.legend.Legend at 0x7fc064221668>
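Observed survival curves like these are typically Kaplan-Meier estimates. We have not checked exactly how `plot_observed_survival` computes its curve, but a minimal Kaplan-Meier sketch (the helper name is hypothetical) looks like this: at each distinct event time, survival is multiplied by the fraction of at-risk subjects who did not fail.

```python
import numpy as np

def kaplan_meier(times, events):
    """Minimal Kaplan-Meier estimator (sketch): returns (event_times, survival)."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])         # distinct times at which failures occur
    surv = []
    s = 1.0
    for tp in event_times:
        at_risk = np.sum(times >= tp)              # still under observation just before tp
        n_events = np.sum((times == tp) & events)  # failures at exactly tp
        s *= 1.0 - n_events / at_risk
        surv.append(s)
    return event_times, np.array(surv)
```

Censored subjects never trigger a drop in the curve, but they leave the at-risk set after their censoring time.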

Transform to long or per-timepoint form

Finally, since this is a PEM model, we transform our data to long or per-timepoint form.


In [6]:
dlong = stancache.cached(
    survivalstan.prep_data_long_surv,
    df=d, event_col='event', time_col='t'
)


INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_33772694934.event_col_event.time_col_t.pkl
INFO:stancache.stancache:prep_data_long_surv: Loading result from cache

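We now have one record per subject (identified by index in the original data frame) per timepoint (distinct values of end_time), up to that subject's observed time. For example, subject 1 has an event at t = 12.83 and contributes 63 rows, with end_failure set to True only in the last one.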


In [7]:
dlong.query('index == 1').sort_values('end_time')


Out[7]:
age sex rate true_t t event index age_centered key end_time end_failure
140 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 0.118611 False
81 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 0.196923 False
139 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 0.262114 False
149 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 0.641174 False
104 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 0.944220 False
136 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 1.105340 False
146 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 1.397562 False
86 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 1.476557 False
135 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 1.530035 False
103 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.111333 False
147 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.330953 False
83 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.357800 False
138 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.639054 False
113 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.724832 False
125 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 2.743388 False
142 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 3.015604 False
118 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 3.095814 False
108 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 3.471401 False
143 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 3.637968 False
126 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 3.792521 False
133 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 4.090998 False
128 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 4.613828 False
119 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 4.829138 False
117 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 4.856847 False
96 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 5.008202 False
95 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 5.084885 False
141 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 5.359748 False
144 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 6.434233 False
116 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 6.512257 False
90 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 6.688216 False
... ... ... ... ... ... ... ... ... ... ... ...
88 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.001683 False
127 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.157144 False
130 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.329006 False
153 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.351628 False
148 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.405822 False
105 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.417478 False
101 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.442196 False
131 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.561702 False
91 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 7.679609 False
112 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 8.228047 False
151 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 8.263575 False
106 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 8.456715 False
114 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 8.817222 False
82 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.244121 False
98 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.336164 False
109 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.344597 False
123 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.590623 False
92 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.731395 False
124 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 9.984362 False
121 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 10.159427 False
80 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 10.462045 False
93 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 10.787069 False
102 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 11.371130 False
85 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 11.540905 False
155 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 11.751679 False
97 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 12.145235 False
122 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 12.156584 False
115 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 12.157394 False
89 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 12.559011 False
79 58 male 0.082085 12.827519 12.827519 True 1 3.18 1 12.827519 True

63 rows × 11 columns

Fit stan model

Now, we are ready to fit our model using survivalstan.fit_stan_survival_model.

We pass a few parameters to the fit function, many of which are required. See ?survivalstan.fit_stan_survival_model for details.

As we did above, we ask survivalstan to cache this model fit object; see stancache for more details on how this works. If you don't want to use the cache, omit the FIT_FUN parameter and survivalstan will use the standard pystan functionality.
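Under the hood, the fit function has to turn the long-form data frame into the inputs declared in the Stan `data` block above (`N`, `S`, `T`, `M`, `s`, `t`, `event`, `x`, `t_obs`, `t_dur`). The sketch below shows one plausible way to do that; `make_stan_data` is hypothetical, and using `pd.get_dummies` in place of survivalstan's formula handling is an assumption made for illustration. Sample and timepoint ids are 1-based because Stan arrays are indexed from 1.

```python
import numpy as np
import pandas as pd

def make_stan_data(dlong, sample_col, timepoint_end_col, event_col, covariates):
    """Hypothetical sketch: assemble the inputs declared in the Stan `data` block."""
    # Timepoints: sorted distinct end times; durations are first differences
    t_obs = np.sort(dlong[timepoint_end_col].unique())
    t_dur = np.diff(t_obs, prepend=0.0)  # assumes follow-up starts at time 0
    t_id = np.searchsorted(t_obs, dlong[timepoint_end_col].values) + 1
    # Samples: 1-based codes for each distinct subject
    s_id = pd.factorize(dlong[sample_col], sort=True)[0] + 1
    # Covariates: numeric columns pass through, categoricals become indicators
    # (dropping the first level as the reference)
    x = pd.get_dummies(dlong[covariates], drop_first=True).astype(float)
    return {
        'N': len(dlong),
        'S': int(s_id.max()),
        'T': len(t_obs),
        'M': x.shape[1],
        's': s_id,
        't': t_id,
        'event': dlong[event_col].astype(int).values,
        'x': x.values,
        't_obs': t_obs,
        't_dur': t_dur,
    }
```

With `covariates=['age_centered', 'sex']`, this produces two columns in `x` (`age_centered` and a `sex_male` indicator, with females as the reference level), which is roughly what the formula `'~ age_centered + sex'` describes.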


In [8]:
testfit = survivalstan.fit_stan_survival_model(
    model_cohort = 'test model',
    model_code = survivalstan.models.pem_survival_model,
    df = dlong,
    sample_col = 'index',
    timepoint_end_col = 'end_time',
    event_col = 'end_failure',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    )


INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_49777972005.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_49777972005.pystan_2_12_0_0.stanfit.chains_4.data_31278094506.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Starting execution
INFO:stancache.stancache:sampling: Execution completed (0:03:04.556861 elapsed)
INFO:stancache.stancache:sampling: Saving results to cache

Superficial review of convergence

Here we look at some top-level summaries of the posterior draws. This is a minimal example, so we should not expect this model to have converged especially well.

In practice, you would want to do a lot more investigation of convergence issues, etc. For now the goal is to demonstrate the functionalities available here.

We can summarize posterior estimates for a single parameter (e.g. the built-in Stan parameter lp__):


In [9]:
survivalstan.utils.print_stan_summary([testfit], pars='lp__')


            mean   se_mean         sd        2.5%         50%       97.5%      Rhat
lp__ -278.136728  5.000714  50.256551 -360.824357 -284.525363 -177.241899  1.023704

Or, for sets of parameters with the same name:


In [10]:
survivalstan.utils.print_stan_summary([testfit], pars='log_baseline_raw')


                          mean   se_mean        sd      2.5%       50%     97.5%      Rhat
log_baseline_raw[0]   0.018405  0.001410  0.141042 -0.266412  0.008231  0.349270  0.999859
log_baseline_raw[1]   0.018517  0.001413  0.141319 -0.263341  0.008995  0.342834  1.000190
log_baseline_raw[2]   0.018302  0.001414  0.141387 -0.257858  0.009377  0.347234  0.999925
log_baseline_raw[3]   0.016333  0.001371  0.137110 -0.253927  0.008481  0.330025  1.000333
log_baseline_raw[4]   0.011990  0.001368  0.136849 -0.259225  0.005920  0.325580  1.000145
log_baseline_raw[5]   0.013532  0.001370  0.137043 -0.259307  0.007282  0.324367  1.000157
log_baseline_raw[6]   0.011470  0.001354  0.135406 -0.266722  0.004746  0.319663  0.999911
log_baseline_raw[7]   0.012491  0.001347  0.134695 -0.260132  0.006712  0.316462  1.000008
log_baseline_raw[8]   0.011526  0.001393  0.139255 -0.270375  0.004809  0.331765  1.000029
log_baseline_raw[9]   0.007375  0.001363  0.136255 -0.279145  0.003548  0.308510  0.999796
log_baseline_raw[10]  0.005951  0.001356  0.135572 -0.269977  0.004103  0.304337  0.999750
log_baseline_raw[11]  0.005699  0.001387  0.138687 -0.286487  0.002771  0.312076  0.999869
log_baseline_raw[12]  0.004665  0.001376  0.137565 -0.283921  0.000912  0.311435  0.999783
log_baseline_raw[13]  0.007392  0.001347  0.134663 -0.270574  0.002433  0.308283  1.000123
log_baseline_raw[14]  0.004534  0.001342  0.134165 -0.285146  0.001556  0.292968  0.999952
log_baseline_raw[15]  0.005103  0.001340  0.134046 -0.277456  0.002392  0.301938  0.999905
log_baseline_raw[16]  0.003518  0.001378  0.137844 -0.277368  0.002277  0.307551  0.999839
log_baseline_raw[17]  0.002659  0.001332  0.133164 -0.275810  0.001411  0.295238  0.999671
log_baseline_raw[18]  0.003533  0.001351  0.135064 -0.277553  0.001838  0.292063  0.999870
log_baseline_raw[19]  0.002062  0.001383  0.138273 -0.287297 -0.000670  0.307131  1.000221
log_baseline_raw[20]  0.003071  0.001356  0.135590 -0.292247  0.000741  0.302829  0.999768
log_baseline_raw[21] -0.001834  0.001394  0.139369 -0.296805 -0.000915  0.287677  0.999744
log_baseline_raw[22] -0.002089  0.001341  0.134124 -0.286115 -0.000353  0.282904  0.999881
log_baseline_raw[23] -0.002234  0.001374  0.137426 -0.304963 -0.001630  0.291352  1.000080
log_baseline_raw[24] -0.000190  0.001343  0.134303 -0.283250 -0.000253  0.294825  0.999723
log_baseline_raw[25] -0.001287  0.001369  0.136858 -0.295716 -0.000231  0.287421  0.999792
log_baseline_raw[26] -0.002714  0.001372  0.137223 -0.308940 -0.001566  0.291775  0.999712
log_baseline_raw[27] -0.006300  0.001364  0.136440 -0.302099 -0.003365  0.270088  0.999816
log_baseline_raw[28] -0.006234  0.001344  0.134426 -0.304960 -0.002756  0.279475  1.000183
log_baseline_raw[29] -0.005409  0.001320  0.131976 -0.292328 -0.002348  0.274455  0.999847
log_baseline_raw[30] -0.007161  0.001320  0.132008 -0.300255 -0.003781  0.268071  0.999809
log_baseline_raw[31] -0.006813  0.001362  0.136225 -0.312780 -0.002422  0.278778  0.999783
log_baseline_raw[32] -0.005802  0.001366  0.136638 -0.308640 -0.002019  0.280386  1.000023
log_baseline_raw[33] -0.008136  0.001322  0.132165 -0.294911 -0.005311  0.267101  0.999791
log_baseline_raw[34] -0.005414  0.001347  0.134679 -0.301907 -0.001826  0.271617  0.999739
log_baseline_raw[35] -0.005833  0.001332  0.133184 -0.294655 -0.003908  0.264071  0.999886
log_baseline_raw[36] -0.005597  0.001333  0.133337 -0.292467 -0.002439  0.264261  0.999732
log_baseline_raw[37] -0.006969  0.001347  0.134747 -0.301223 -0.002385  0.271992  0.999948
log_baseline_raw[38] -0.004569  0.001330  0.133014 -0.291487 -0.003056  0.271717  1.000080
log_baseline_raw[39] -0.005369  0.001377  0.137687 -0.311480 -0.001002  0.278568  0.999943
log_baseline_raw[40] -0.004928  0.001316  0.131619 -0.293840 -0.002247  0.271888  0.999953
log_baseline_raw[41] -0.006589  0.001354  0.135430 -0.310065 -0.002537  0.275856  0.999951
log_baseline_raw[42] -0.005687  0.001369  0.136899 -0.308345 -0.002883  0.281115  0.999867
log_baseline_raw[43] -0.007082  0.001337  0.133696 -0.307673 -0.002852  0.269971  1.000021
log_baseline_raw[44] -0.006966  0.001380  0.138005 -0.305034 -0.004202  0.277748  0.999742
log_baseline_raw[45] -0.006925  0.001344  0.134400 -0.302803 -0.003980  0.271892  0.999753
log_baseline_raw[46] -0.007264  0.001343  0.134262 -0.300234 -0.005482  0.274120  0.999905
log_baseline_raw[47] -0.008346  0.001373  0.137330 -0.312680 -0.003998  0.273031  1.000193
log_baseline_raw[48] -0.006099  0.001325  0.132452 -0.296824 -0.002946  0.271214  0.999884
log_baseline_raw[49] -0.007845  0.001339  0.133894 -0.307824 -0.002242  0.268274  0.999954
log_baseline_raw[50] -0.006161  0.001353  0.135278 -0.310913 -0.002115  0.270951  1.000001
log_baseline_raw[51] -0.004106  0.001367  0.136722 -0.305186 -0.000375  0.280699  0.999947
log_baseline_raw[52] -0.006734  0.001337  0.133681 -0.307962 -0.003439  0.277566  0.999966
log_baseline_raw[53] -0.004135  0.001350  0.134955 -0.294451 -0.000209  0.275918  0.999780
log_baseline_raw[54] -0.006873  0.001353  0.135333 -0.303199 -0.002728  0.273141  1.000029
log_baseline_raw[55] -0.007751  0.001396  0.139585 -0.318518 -0.003133  0.285527  0.999913
log_baseline_raw[56] -0.007375  0.001338  0.133818 -0.307137 -0.002825  0.275709  0.999920
log_baseline_raw[57] -0.008380  0.001351  0.135141 -0.311733 -0.004041  0.267765  0.999744
log_baseline_raw[58] -0.003854  0.001357  0.135702 -0.297579 -0.001496  0.288015  0.999748
log_baseline_raw[59] -0.005154  0.001355  0.135456 -0.299541 -0.001827  0.272893  0.999782
log_baseline_raw[60] -0.005536  0.001351  0.135127 -0.309023 -0.002620  0.279765  0.999889
log_baseline_raw[61] -0.005031  0.001368  0.136847 -0.310338 -0.002176  0.283778  1.000248
log_baseline_raw[62] -0.004126  0.001335  0.133517 -0.292910 -0.001112  0.273674  0.999906
log_baseline_raw[63] -0.006565  0.001341  0.134096 -0.299221 -0.003177  0.267685  1.000108
log_baseline_raw[64] -0.005644  0.001361  0.136146 -0.307535 -0.002344  0.276061  0.999890
log_baseline_raw[65] -0.005163  0.001322  0.132204 -0.288991 -0.002578  0.277272  1.000024
log_baseline_raw[66] -0.003780  0.001344  0.134406 -0.298615 -0.002005  0.282805  0.999966
log_baseline_raw[67] -0.003824  0.001327  0.132728 -0.292247 -0.003482  0.279543  0.999709
log_baseline_raw[68] -0.007090  0.001373  0.137329 -0.304829 -0.004162  0.273670  0.999759
log_baseline_raw[69] -0.006359  0.001352  0.135188 -0.300134 -0.003754  0.273260  0.999711
log_baseline_raw[70] -0.004139  0.001339  0.133909 -0.299168 -0.001363  0.278754  0.999813
log_baseline_raw[71] -0.005135  0.001337  0.133750 -0.302788 -0.002349  0.277571  0.999863
log_baseline_raw[72] -0.004251  0.001329  0.132932 -0.295977 -0.003475  0.273927  0.999838
log_baseline_raw[73] -0.004634  0.001385  0.138541 -0.305831 -0.001851  0.290301  0.999982
log_baseline_raw[74] -0.004999  0.001328  0.132767 -0.298910 -0.002887  0.269649  0.999897
log_baseline_raw[75] -0.004575  0.001344  0.134395 -0.304516 -0.001487  0.278224  0.999981
log_baseline_raw[76] -0.003373  0.001372  0.137236 -0.301230 -0.001954  0.288494  0.999889
log_baseline_raw[77] -0.021396  0.001397  0.139665 -0.349297 -0.009878  0.254990  1.000517

It is also common to summarize the Rhat values graphically, to get a sense of how similar the chains are for particular parameters.
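For reference, Rhat compares between-chain and within-chain variance; values near 1 suggest the chains agree. The NumPy sketch below implements the classic split-$\hat{R}$ (each chain is split in half before comparing), which may differ in detail from the variant pystan reports; `split_rhat` is a hypothetical helper that takes draws shaped `(n_chains, n_iter)`.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for one parameter, draws shaped (n_chains, n_iter)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in half so drift within a chain also shows up as disagreement
    halves = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    W = halves.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_plus = (n - 1) / n * W + B / n      # pooled estimate of posterior variance
    return np.sqrt(var_plus / W)
```

Chains that sample the same distribution give values close to 1; chains stuck in different regions give values well above 1.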


In [11]:
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline_raw')


Plot posterior estimates of parameters

We can use plot_coefs to summarize posterior estimates of parameters.

In this basic pem_survival_model, we estimate a baseline hazard parameter for each observed timepoint, which is then adjusted for the duration of that timepoint. For consistency, the reported baseline values are normalized to the unit of time used in the input data, which lets us compare hazard estimates across timepoints without having to know the duration of each one. In the Stan code above, log_baseline includes the log-duration term and enters the likelihood, while baseline (computed in generated quantities as exp(log_baseline_mu + log_baseline_raw)) omits it and is therefore on a per-unit-time scale.
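The duration adjustment comes from a standard equivalence (often called the "Poisson trick"): for a subject at risk for a period of length $\Delta$ with constant hazard $\lambda$, the exponential log-likelihood $d \log \lambda - \lambda \Delta$ (with $d \in \{0, 1\}$ the event indicator) differs from a Poisson log-likelihood with log-mean $\log \lambda + \log \Delta$ only by $d \log \Delta$, a term that contains no parameters. That is why the model can use `poisson_log` with a log-duration term added to the linear predictor. A small check in plain Python (function names are hypothetical):

```python
import math

def exponential_piece_loglik(event, hazard, duration):
    """Exponential log-likelihood for one subject-period; event is 0 or 1."""
    return event * math.log(hazard) - hazard * duration

def poisson_log_loglik(event, log_mu):
    """Poisson log-pmf written on the log-mean scale (like Stan's poisson_log)."""
    return event * log_mu - math.exp(log_mu) - math.lgamma(event + 1)

hazard, duration = 0.08, 2.5
for event in (0, 1):
    offset_ll = poisson_log_loglik(event, math.log(hazard) + math.log(duration))
    direct_ll = exponential_piece_loglik(event, hazard, duration)
    # the difference is event * log(duration), which involves no parameters
    assert math.isclose(offset_ll - direct_ll, event * math.log(duration), abs_tol=1e-12)
```

Because the difference does not depend on $\lambda$, both likelihoods lead to the same posterior for the hazard parameters.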

In this model, the baseline hazard is parameterized by two components: an overall mean across all timepoints (log_baseline_mu) and a per-timepoint deviation from that mean (log_baseline_raw). The scale of those deviations is estimated from the data as baseline_sigma. All components have weakly informative priors; see the Stan code above for details.

In this case, the model estimates very little variation across timepoints, which is consistent with the simulated data, whose hazard is constant over time.
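To see how these components combine, here is a sketch of a single prior draw of the baseline hazard, following the priors in the `model` block (`log_baseline_mu ~ normal(0, 1)`, a half-normal on `baseline_sigma`, and `log_baseline_raw ~ normal(0, baseline_sigma)`). It assumes the offset added to `log_baseline` is the log of each period's duration, as the comments in the Stan code describe; `draw_baseline_from_prior` is a hypothetical helper.

```python
import numpy as np

def draw_baseline_from_prior(t_dur, rng):
    """One prior draw of the baseline hazard components (illustrative sketch)."""
    t_dur = np.asarray(t_dur, dtype=float)
    log_baseline_mu = rng.normal(0.0, 1.0)
    baseline_sigma = abs(rng.normal(0.0, 1.0))   # half-normal, since sigma has lower=0
    log_baseline_raw = rng.normal(0.0, baseline_sigma, size=t_dur.shape)
    # per-unit-time hazard for each timepoint (what `baseline` reports)
    baseline = np.exp(log_baseline_mu + log_baseline_raw)
    # duration-adjusted log hazard that enters the likelihood
    log_baseline = log_baseline_mu + log_baseline_raw + np.log(t_dur)
    return baseline, log_baseline
```

Exponentiating the duration-adjusted value gives the expected number of events per subject in each period at baseline, i.e. the per-unit-time hazard multiplied by the period's duration.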


In [12]:
survivalstan.utils.plot_coefs([testfit], element='baseline')


We can also summarize the posterior estimates for our beta coefficients; this is the default behavior of plot_coefs. Here we hope to see that the posterior for the sex coefficient includes the value we used in our simulation (0.5), and that the posterior for age_centered includes 0, since age was not part of the simulated rate.
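Because the coefficients are on the log-hazard scale, a sex coefficient of 0.5 corresponds to a hazard ratio of $\exp(0.5) \approx 1.65$, which matches the ratio of the two simulated rates ($0.082085 / 0.049787$). Below is a minimal sketch for summarizing posterior draws of a coefficient as hazard ratios; how you obtain the draws (for example, from the pystan fit object) is left as an assumption, and `hazard_ratio_summary` is hypothetical.

```python
import numpy as np

def hazard_ratio_summary(beta_draws, level=0.95):
    """Median and central interval of exp(beta), i.e. the hazard ratio."""
    hr = np.exp(np.asarray(beta_draws, dtype=float))
    tail = 100 * (1 - level) / 2
    return (np.percentile(hr, 50),
            np.percentile(hr, tail),
            np.percentile(hr, 100 - tail))
```

A hazard ratio interval that excludes 1 corresponds to a coefficient interval that excludes 0.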


In [13]:
survivalstan.utils.plot_coefs([testfit])


Posterior predictive checking

Finally, survivalstan provides some utilities for posterior predictive checking.

The goal of posterior-predictive checking is to compare the uncertainty of model predictions to observed values.

We are not doing true out-of-sample predictions, but we are able to sanity-check our model's calibration. We expect approximately 5% of observed values to fall outside of their corresponding 95% posterior-predicted intervals.

By default, survivalstan's plot_pp_survival method will plot whiskers at the 2.5th and 97.5th percentile values, corresponding to 95% predicted intervals.
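Checking calibration amounts to counting how many observed values fall inside their central posterior-predictive intervals. A NumPy sketch, assuming the predictions have been arranged as an array of shape `(n_draws, n_points)` aligned with the observed values (`interval_coverage` is hypothetical):

```python
import numpy as np

def interval_coverage(pp_draws, observed, level=0.95):
    """Fraction of observed values inside their central posterior-predictive interval."""
    pp_draws = np.asarray(pp_draws, dtype=float)   # shape (n_draws, n_points)
    observed = np.asarray(observed, dtype=float)   # shape (n_points,)
    tail = 100 * (1 - level) / 2
    lower = np.percentile(pp_draws, tail, axis=0)
    upper = np.percentile(pp_draws, 100 - tail, axis=0)
    return np.mean((observed >= lower) & (observed <= upper))
```

For a well-calibrated model, `interval_coverage(..., level=0.95)` should come out near 0.95.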


In [14]:
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(df=d, event_col='event', time_col='t', color='green', label='observed')
plt.legend()


Out[14]:
<matplotlib.legend.Legend at 0x7fbfc340de80>

We can also summarize and plot survival by covariates of interest, as long as they are included in the dataframe passed to fit_stan_survival_model.


In [15]:
survivalstan.utils.plot_pp_survival([testfit], by='sex')


This plot can also be customized with a variety of aesthetic options:


In [16]:
survivalstan.utils.plot_pp_survival([testfit], by='sex', pal=['red', 'blue'])


Building up the plot semi-manually, for more customization

We can also call the utility functions in survivalstan.utils directly to produce roughly the same plot. This sequence is intended both to illustrate how the plot above was constructed and to expose some of the functionality more concretely.

Probably the most useful piece is the ability to summarize and return the posterior-predicted values themselves:


In [17]:
ppsurv = survivalstan.utils.prep_pp_survival_data([testfit], by='sex')

Here is what these data look like:


In [18]:
ppsurv.head()


Out[18]:
   iter model_cohort     sex  level_3  event_time  survival
0     0   test model  female        0    0.000000  1.000000
1     0   test model  female        1    1.397562  1.000000
2     0   test model  female        2    2.330953  0.974215
3     0   test model  female        3    2.357800  0.955702
4     0   test model  female        4    2.743388  0.913719

(Note that this is itself a summary of the posterior draws returned by survivalstan.utils.prep_pp_data. In this case, the survival statistics are summarized by values of ['iter', 'model_cohort', by].)

We can then call out to survivalstan.utils._plot_pp_survival_data to construct the plot. In this case, we overlay the posterior predicted intervals with observed values.


In [19]:
subplot = plt.subplots(1, 1)
survivalstan.utils._plot_pp_survival_data(ppsurv.query('sex == "male"').copy(),
                                          subplot=subplot, color='blue', alpha=0.5)
survivalstan.utils._plot_pp_survival_data(ppsurv.query('sex == "female"').copy(),
                                          subplot=subplot, color='red', alpha=0.5)
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t',
                                          color='red', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t',
                                          color='blue', label='male')
plt.legend()


Out[19]:
<matplotlib.legend.Legend at 0x7fbfc304a588>

Use plotly to summarize posterior predicted values

First, we precompute 50% and 95% posterior intervals (plus the median) of survival at each observed event time, by group.


In [20]:
ppsummary = ppsurv.groupby(['sex', 'event_time'])['survival'].agg(**{
        # named aggregation: passing a dict of renamers to a single column's .agg()
        # was deprecated and is rejected by recent pandas versions
        '95_lower': lambda x: np.percentile(x, 2.5),
        '95_upper': lambda x: np.percentile(x, 97.5),
        '50_lower': lambda x: np.percentile(x, 25),
        '50_upper': lambda x: np.percentile(x, 75),
        'median': lambda x: np.percentile(x, 50),
    }).reset_index()
shade_colors = dict(male='rgba(0, 128, 128, {})', female='rgba(214, 12, 140, {})')
line_colors = dict(male='rgb(0, 128, 128)', female='rgb(214, 12, 140)')
ppsummary.sort_values(['sex', 'event_time'], inplace=True)

Next, we construct our graph "traces", consisting of 3 elements (solid line and two shaded areas) per observed group.


In [21]:
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
plotly.offline.init_notebook_mode(connected=True)



In [22]:
data5 = list()
for grp, grp_df in ppsummary.groupby('sex'):
    x = list(grp_df['event_time'].values)
    x_rev = x[::-1]
    # 50% band: trace the upper bound forward, then the lower bound backward
    y_upper = list(grp_df['50_upper'].values)
    y_lower = list(grp_df['50_lower'].values)[::-1]
    # 95% band, built the same way
    y2_upper = list(grp_df['95_upper'].values)
    y2_lower = list(grp_df['95_lower'].values)[::-1]
    y = list(grp_df['median'].values)
    my_shading50 = go.Scatter(
        x = x + x_rev,
        y = y_upper + y_lower,
        fill = 'tozerox',
        fillcolor = shade_colors[grp].format(0.3),
        line = dict(color = 'transparent'),  # plain dict instead of deprecated go.Line
        showlegend = True,
        name = '{} - 50% CI'.format(grp),
    )
    my_shading95 = go.Scatter(
        x = x + x_rev,
        y = y2_upper + y2_lower,
        fill = 'tozerox',
        fillcolor = shade_colors[grp].format(0.1),
        line = dict(color = 'transparent'),
        showlegend = True,
        name = '{} - 95% CI'.format(grp),
    )
    # median survival curve for this group
    my_line = go.Scatter(
        x = x,
        y = y,
        line = dict(color = line_colors[grp]),
        mode = 'lines',
        name = grp,
    )
    data5.append(my_line)
    data5.append(my_shading50)
    data5.append(my_shading95)

Finally, we build a minimal layout structure to house our graph:


In [23]:
layout5 = go.Layout(
    yaxis=dict(
        title='Survival (%)',
        #zeroline=False,
        tickformat='.0%',
    ),
    xaxis=dict(title='Days since enrollment')
)

Here is our plot:


In [24]:
py.iplot(go.Figure(data=data5, layout=layout5), filename='survivalstan/pem_survival_model_ppsummary')


Out[24]:

Note: this plot will not render on GitHub, since GitHub disables iframes. You can, however, view it in nbviewer or directly on plotly's website.